import torch

# 32-element int64 label vector: entry i holds (i + 16) % 32, i.e. the
# first half is 16..31 and the second half is 0..15. Rolling 0..31 right by
# 16 places produces exactly that ordering.
# NOTE(review): presumably these are "partner index" targets for 16 pairs
# of samples (e.g. a contrastive loss over two views) — confirm with caller.
labels = torch.arange(32).roll(16)
print(labels)

# Small 3x3 float32 example matrix.
inputs = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float32)

# Boolean mask that is True only on the main diagonal. It is sized and placed
# from ``inputs`` itself so the mask stays correct if the matrix changes.
mask = torch.eye(
    inputs.size(0), inputs.size(1), dtype=torch.bool, device=inputs.device
)

# Replace the diagonal with -inf. Out-of-place ``masked_fill`` leaves
# ``inputs`` intact; the in-place ``masked_fill_`` would overwrite ``inputs``
# and make ``sim_matrix`` an alias of the same storage.
sim_matrix = inputs.masked_fill(mask, float("-inf"))
print(sim_matrix)
